"""Gist compression demo."""

from typing import Optional

import fire
import torch
from transformers import AutoConfig, AutoTokenizer, LlamaTokenizer

from . import gist_llama, gist_t5, weight_diff
from .gist_llama import GistLlamaForCausalLM
from .gist_t5 import GistT5ForConditionalGeneration


def humanbytes(B):
    """Return the given bytes as a human friendly KB, MB, GB, or TB string.

    Adapted from https://stackoverflow.com/a/31631711/2980246

    Args:
        B: Number of bytes; anything accepted by ``float()``.

    Returns:
        A string such as ``"512.0 Bytes"``, ``"1.50 KB"`` or ``"2.00 GB"``.
        Values below 1 KB are printed as the float value followed by
        ``Byte``/``Bytes``; larger values are scaled by powers of 1024 and
        printed with two decimals.
    """
    B = float(B)
    KB = float(1024)
    MB = float(KB**2)  # 1,048,576
    GB = float(KB**3)  # 1,073,741,824
    TB = float(KB**4)  # 1,099,511,627,776

    if B < KB:
        # The previous test (`0 == B > 1`) is a chained comparison that can
        # never be true, so every value was labelled "Byte". Only exactly one
        # byte is singular.
        unit = "Byte" if B == 1 else "Bytes"
        return "{0} {1}".format(B, unit)
    if B < MB:
        return "{0:.2f} KB".format(B / KB)
    if B < GB:
        return "{0:.2f} MB".format(B / MB)
    if B < TB:
        return "{0:.2f} GB".format(B / GB)
    # Everything else is reported in TB so the function always returns a str.
    return "{0:.2f} TB".format(B / TB)


@torch.inference_mode()
def main(
    model_name_or_path: str,
    input: str = "",
    num_gist_tokens: Optional[int] = 1,
    cache_dir: str = ".cache",
    precision: str = "fp32",
    max_new_tokens: int = 512,
    base_llama_path: Optional[str] = None,
) -> None:
    """Compare gist activations of a gist model across a fixed set of instructions.

    Loads the model and tokenizer, compresses each instruction in a hard-coded
    list into gist activations, and prints the pairwise cosine distances
    between the instructions' flattened last hidden states and between the
    flattened ``[0]`` / ``[1]`` tensors of the last ``past_key_values`` entry
    (presumably the final layer's keys and values -- confirm against the
    model's ``get_gist_activations``).

    The model is moved to CUDA, so a GPU is required.

    Args:
        model_name_or_path: The model to load. MUST BE A GIST MODEL. The name
            must contain "t5" or "llama" (case-insensitive); "t5" takes
            precedence if both appear.
        input: Currently unused by this script; kept for CLI compatibility.
        num_gist_tokens: number of gist tokens to compress to. This should
            match the number of gist tokens the model was trained on.
        cache_dir: Hugging Face cache dir (used for the config and model).
        precision: Precision to load the model in; one of "bf16", "fp16",
            "fp32". Recommend fp32 or bf16 (not fp16).
        max_new_tokens: Currently unused by this script; kept for CLI
            compatibility.
        base_llama_path: Any LLaMA model loaded from Hugging Face
            (jayelm/llama-7b-{gist,pos_control,neg_control}-1) is a weight
            diff, not the full model. If loading one of the Hugging Face LLaMA
            models, use this argument to specify the path to the raw LLaMA model.

    Raises:
        ValueError: If ``precision`` is not a supported key, if the model type
            cannot be inferred from ``model_name_or_path``, or if a weight-diff
            repo is requested without ``base_llama_path``.
        AssertionError: If the tokenizer or embedding size does not equal the
            pretrained vocab size plus one (the gist token).
    """
    # Validate cheap arguments first so a typo fails before any model loading.
    dtypes = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float,
    }
    if precision not in dtypes:
        raise ValueError(
            f"Unsupported precision {precision!r}; expected one of {sorted(dtypes)}"
        )

    is_llama = "llama" in model_name_or_path.lower()
    is_t5 = "t5" in model_name_or_path.lower()
    if not (is_t5 or is_llama):
        raise ValueError(f"Model type {model_name_or_path} not supported")

    # Load config
    config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)

    # Load model
    print(f"Loading model {model_name_or_path}")
    model_cls = GistT5ForConditionalGeneration if is_t5 else GistLlamaForCausalLM

    if model_name_or_path in {
        "jayelm/llama-7b-gist-1",
        "jayelm/llama-7b-pos_control-1",
        "jayelm/llama-7b-neg_control-1",
    }:
        # Load with weight diff file
        if base_llama_path is None:
            raise ValueError(
                f"{model_name_or_path} is a weight diff huggingface repo. "
                "You must specify a `base_llama_path` for this to work."
            )
        print("Weight diff detected. Applying to original model...")
        model, _ = weight_diff.recover(
            path_raw=base_llama_path,
            path_diff=model_name_or_path,
            test_inference=False,
            cache_dir=cache_dir,
        )
    else:
        model = model_cls.from_pretrained(
            model_name_or_path,
            config=config,
            cache_dir=cache_dir,
        )

    model = model.to(dtypes[precision]).cuda().eval()

    # Load tokenizer. It must already have gist token defined.
    print("Loading tokenizer")
    if is_llama:
        tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
        assert len(tokenizer) == gist_llama.PRETRAINED_VOCAB_SIZE + 1
        assert model.lm_head.weight.shape[0] == gist_llama.PRETRAINED_VOCAB_SIZE + 1
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        assert len(tokenizer) == gist_t5.PRETRAINED_VOCAB_SIZE + 1
        assert model.shared.weight.shape[0] == gist_t5.PRETRAINED_VOCAB_SIZE + 1
    # The gist token is taken to be the last additional special token.
    gist_token = tokenizer.additional_special_tokens_ids[-1]
    print(gist_token)

    # The gist suffix is identical for every instruction, so build it once.
    gist_str = "<GIST>" * num_gist_tokens

    hs, k, v = [], [], []
    instructions = [
        "Computer Science",
        "Biology",
        "Chemistry",
        "Mathematics",
        "English",
        "French",
    ]
    for instruction in instructions:
        # Compress instruction
        prepped_instruction = f"Instruction: {instruction}\n{gist_str}"
        instruction_input_ids = tokenizer.encode(prepped_instruction)
        print(instruction, instruction_input_ids)
        if is_t5:
            instruction_input_ids = instruction_input_ids[:-1]  # Remove eos token
        instruction_input_ids_tensor = (
            torch.tensor(instruction_input_ids).unsqueeze(0).cuda()
        )
        gist_kwargs = {
            "input_ids": instruction_input_ids_tensor,
            "attention_mask": torch.ones_like(instruction_input_ids_tensor),
        }
        if is_llama:
            # The LLaMA variant additionally takes a gist attention mask with
            # two extra leading dimensions; all ones here.
            gist_kwargs["attention_mask_gist"] = torch.ones_like(
                instruction_input_ids_tensor
            )[None, None]
        gist_activations = model.get_gist_activations(
            gist_token=gist_token,
            num_gist_tokens=num_gist_tokens,
            **gist_kwargs,
        )

        last_hidden_state = gist_activations.last_hidden_state
        past_key_values = gist_activations.past_key_values

        # Upcast to float32 before converting: NumPy has no bfloat16 dtype, so
        # `.numpy()` raises on bf16 tensors. pdist works in float64 anyway.
        hs.append(last_hidden_state.detach().cpu().float().numpy().flatten())
        k.append(past_key_values[-1][0].detach().cpu().float().numpy().flatten())
        v.append(past_key_values[-1][1].detach().cpu().float().numpy().flatten())

    # Imported lazily so scipy is only required when the analysis runs.
    from scipy.spatial.distance import pdist

    # Each pdist result is the condensed pairwise-distance vector over the
    # instructions, in the order they are listed above.
    for dist in ["cosine"]:
        hs_dm = pdist(hs, dist)
        k_dm = pdist(k, dist)
        v_dm = pdist(v, dist)
        print(dist, hs_dm, k_dm, v_dm)




# Script entry point: python-fire turns `main`'s parameters into command-line
# arguments. This module uses package-relative imports, so it has to be run as
# a module inside its package (`python -m <package>.<module>`), not as a
# standalone file.
if __name__ == "__main__":
    fire.Fire(main)


